Fix np bfloat16 misinterpreted as complex#3146
Conversation
|
thanks for the feedback! |
|
I pushed a change to reuse |
|
Thanks for the help, this builds and runs fine on my machine (the only tests I had were the ones in Also, I really enjoyed working on the binding and types for this PR. I'm looking for another issue to pick up, but the public ones seem pretty sparse right now. I was wondering if the mlx team had any feature buildouts targeted (maybe internally) that are open to community contribution? Perhaps I could work on something there? |
|
Thanks for your interests! There are some python related issues that I would be happy to review contributions: |
Co-authored-by: Cheng <git@zcbenz.com>
Proposed changes
Fixes #1075
Bug: The bug happens when converting
np.array(1., dtype=ml_dtypes.bfloat16)andnp.array([1.], dtype=ml_dtypes.bfloat16)tomx.array(x). For the former case, it'll silently be caught as astd::complexas part ofArrayInitTypeand get converted as such (see related issue for code). For the latter, it'll be interpreted as anArrayLike, not be able to make the conversion tomx.array()and raise aValueError.The Fix: We need to catch this case before it gets filtered by
ArrayInitType. I made thearray.__init__more generic to catch this and checked the dtype to match bfloat16, then manually construct the array. Otherwise, we fallback to the originalArrayInitTypecase.Note: bfloat16, is the only current ml_dtype that mlx supports.
Verification: Verified locally (macOS 26.2, MLX 0.30.7, Apple M2) with the additional test case I provided. If this is run from main, it raises the bug mentioned above:
Checklist
Put an
xin the boxes that apply.pre-commit run --all-filesto format my code / installed pre-commit prior to committing changes